// Row-wise inclusive prefix sum of a row-major b x n float array:
//   out[i*n+j] = inp[i*n+0] + ... + inp[i*n+j]
// Parameters:
//   b   - number of rows
//   n   - number of elements per row
//   inp - input values, read at inp[i*n+j]
//   out - output values, written at out[i*n+j]
// Work distribution: a block handles rows blockIdx.x, blockIdx.x+gridDim.x, ...
// and walks each row in chunks of BlockSize*4 elements. Within a chunk the scan is
// done in two levels: every group of 4 consecutive elements is summed serially by
// one thread (results in buffer4, group total in buffer), then the group totals are
// scanned cooperatively in shared memory, and finally added back to the groups.
// Static shared memory: (BlockSize*4 + BlockSize + BlockSize/32) floats = 41216 bytes.
// All __syncthreads() calls sit in loops whose bounds depend only on n, j and u,
// which are identical for every thread of the block.
__global__ void cumsumKernel(int b,int n,const float * __restrict__ inp,float * __restrict__ out){
	const int BlockSize=2048;
	// buffer is indexed as idx+(idx>>paddingLevel), i.e. one extra slot after every
	// 32 entries (presumably to spread accesses across shared-memory banks).
	const int paddingLevel=5;
	// per-element partial sums of the current chunk (sums within each group of 4 first,
	// full chunk-local prefix sums after the fix-up loop below)
	__shared__ float buffer4[BlockSize*4];
	// one entry per group of 4 elements, stored with the padded indexing described above
	__shared__ float buffer[BlockSize+(BlockSize>>paddingLevel)];
	for (int i=blockIdx.x;i<b;i+=gridDim.x){
		// runningsum: sum of all previous chunks of this row;
		// runningsum2: rounding-error compensation term for runningsum (see end of loop)
		float runningsum=0,runningsum2=0;
		for (int j=0;j<n;j+=BlockSize*4){
			//int n2=min(n-j,BlockSize);
			/*for (int k=threadIdx.x;k<n2;k+=blockDim.x){
				buffer[k+(k>>paddingLevel)]=inp[i*n+j+k];
			}*/
			// n24_i: number of valid elements in this chunk
			int n24_i=min(n-j,BlockSize*4);
			// n24: n24_i rounded up to a multiple of 4
			int n24=(n24_i+3)&~3;
			// n2: number of groups of 4 in this chunk
			int n2=n24>>2;
			/*for (int k=threadIdx.x;k<n2;k+=blockDim.x){
				buffer[k+(k>>paddingLevel)]=inp[i*n+j+k];
			}*/
			// Stage 1: each thread loads whole groups of 4 elements (k is the first
			// element of the group) and computes the inclusive prefix sums inside the group.
			for (int k=threadIdx.x*4;k<n24_i;k+=blockDim.x*4){
				if (k+3<n24_i){
					// complete group: after the additions below
					//   v1=a0, v2=a0+a1, v3=a0+a1+a2, v4=a0+a1+a2+a3
					float v1=inp[i*n+j+k];
					float v2=inp[i*n+j+k+1];
					v2+=v1;
					float v3=inp[i*n+j+k+2];
					float v4=inp[i*n+j+k+3];
					v4+=v3;
					v3+=v2;
					v4+=v2;
					buffer4[k]=v1;
					buffer4[k+1]=v2;
					buffer4[k+2]=v3;
					buffer4[k+3]=v4;
					// group total goes to slot k/4 of buffer (padded index)
					buffer[(k>>2)+(k>>(2+paddingLevel))]=v4;
				}else{
					// last, incomplete group of the row: accumulate the valid elements serially
					float v=0;
					for (int k2=k;k2<n24_i;k2++){
						v+=inp[i*n+j+k2];
						buffer4[k2]=v;
					}
					// fill the unused slots of the group with the final partial sum
					for (int k2=n24_i;k2<n24;k2++){
						buffer4[k2]=v;
					}
					buffer[(k>>2)+(k>>(2+paddingLevel))]=v;
				}
			}
			// Stage 2a: pairwise tree over the n2 group totals. At level u the element
			// at position (2k+2)*2^u-1 accumulates the element at (2k+1)*2^u-1.
			int u=0;
			for (;(2<<u)<=n2;u++){
				__syncthreads();
				for (int k=threadIdx.x;k<int(n2>>(u+1));k+=blockDim.x){
					int i1=(((k<<1)+2)<<u)-1;
					int i2=(((k<<1)+1)<<u)-1;
					// convert logical positions to padded storage positions
					i1+=i1>>paddingLevel;
					i2+=i2>>paddingLevel;
					buffer[i1]+=buffer[i2];
				}
			}
			// Stage 2b: walk the levels back down; the element at (2k+3)*2^u-1
			// accumulates the element at (2k+2)*2^u-1. The code below then uses buffer as
			// the inclusive prefix sum of the group totals.
			u--;
			for (;u>=0;u--){
				__syncthreads();
				for (int k=threadIdx.x;k<int((n2-(1<<u))>>(u+1));k+=blockDim.x){
					int i1=(((k<<1)+3)<<u)-1;
					int i2=(((k<<1)+2)<<u)-1;
					i1+=i1>>paddingLevel;
					i2+=i2>>paddingLevel;
					buffer[i1]+=buffer[i2];
				}
			}
			__syncthreads();
			/*for (int k=threadIdx.x;k<n2;k+=blockDim.x){
				out[i*n+j+k]=buffer[k+(k>>paddingLevel)]+runningsum;
			}*/
			// Stage 3: add the scanned total of all preceding groups (buffer slot k/4-1)
			// to the four entries of each group; group 0 has no predecessor.
			// The k -> thread mapping is the same as in stage 1, so each thread updates
			// the buffer4 entries it wrote itself.
			for (int k=threadIdx.x*4;k<n24;k+=blockDim.x*4){
				if (k!=0){
					int k2=((k>>2)-1)+(((k>>2)-1)>>paddingLevel);
					buffer4[k]+=buffer[k2];
					buffer4[k+1]+=buffer[k2];
					buffer4[k+2]+=buffer[k2];
					buffer4[k+3]+=buffer[k2];
				}
			}
			__syncthreads();
			// Stage 4: write the valid elements, offset by the sum of the previous chunks.
			// Consecutive threads write consecutive addresses here.
			for (int k=threadIdx.x;k<n24_i;k+=blockDim.x){
				out[i*n+j+k]=buffer4[k]+runningsum;
			}
			//float t=buffer4[n24-1]+runningsum2;
			// Carry the chunk total (last scanned group entry) into runningsum using a
			// compensated (Kahan-style) update: runningsum2 keeps the part of t that was
			// lost when rounding runningsum+t.
			float t=buffer[(n2-1)+((n2-1)>>paddingLevel)]+runningsum2;
			float r2=runningsum+t;
			runningsum2=t-(r2-runningsum);
			runningsum=r2;
			// make sure every thread is done with buffer/buffer4 before the next chunk overwrites them
			__syncthreads();
		}
	}
}
// For every query value, finds a position in the matching row of `dataset`.
//   b, n    - dataset is read as b rows of n floats (dataset[i*n+...])
//   m       - queries per row; query and result are indexed as [i*m+j]
//   dataset - row values; the search below is only meaningful if each row is
//             non-decreasing (e.g. a cumulative sum) -- assumed, not checked
//   query   - each query is multiplied by the last element of its row before searching
//   result  - receives the position found for each query
// The search starts at the last position n-1 and, with halving step sizes, moves
// left whenever the element `step` places to the left is still >= the scaled query.
// Rows are distributed over blockIdx.x/gridDim.x, queries over blockIdx.y,
// threadIdx.x with stride blockDim.x*gridDim.y.
__global__ void binarysearchKernel(int b,int n,int m,const float * __restrict__ dataset,const float * __restrict__ query, int * __restrict__ result){
	// largest step to try: the smallest power of two that is not below n
	int topStep=1;
	for (;topStep<n;topStep*=2){
	}
	const int firstQuery=blockIdx.y*blockDim.x+threadIdx.x;
	const int queryStride=blockDim.x*gridDim.y;
	for (int row=blockIdx.x;row<b;row+=gridDim.x){
		const float *rowData=dataset+row*n;
		for (int col=firstQuery;col<m;col+=queryStride){
			// scale the query by the row's last element
			const float target=query[row*m+col]*rowData[n-1];
			int pos=n-1;
			for (int step=topStep;step>0;step/=2){
				if (pos<step)
					continue;	// stepping left would leave the row
				if (rowData[pos-step]>=target)
					pos-=step;
			}
			result[row*m+col]=pos;
		}
	}
}
// Iterative farthest point sampling on 3-float points.
//   b       - number of point sets (batches)
//   n       - points per set; dataset is read as dataset[i*n*3 + k*3 + {0,1,2}]
//   m       - number of indices to produce per set (nothing is done if m<=0)
//   dataset - input coordinates
//   temp    - scratch, indexed as temp[blockIdx.x*n + k]; holds for every point the
//             smallest squared distance to any point selected so far
//             (needs gridDim.x*n floats)
//   idxs    - output, idxs[i*m + j] is the j-th selected point of set i;
//             the first selected point is always index 0
// Each block processes sets blockIdx.x, blockIdx.x+gridDim.x, ...
// Preconditions (not checked): blockDim.x <= 512 (size of the shared reduction
// arrays); the tree reduction below pairs elements by halving blockDim.x, so it
// covers every thread only when blockDim.x is a power of two.
__global__ void farthestpointsamplingKernel(int b,int n,int m,const float * __restrict__ dataset,float * __restrict__ temp,int * __restrict__ idxs){
	if (m<=0)
		return;
	const int BlockSize=512;
	// per-thread best candidate (distance, index) for the block-wide max reduction
	__shared__ float dists[BlockSize];
	__shared__ int dists_i[BlockSize];
	// shared-memory copy of the first BufferSize points of the current set
	const int BufferSize=3072;
	__shared__ float buf[BufferSize*3];
	for (int i=blockIdx.x;i<b;i+=gridDim.x){
		// index of the most recently selected point
		int old=0;
		if (threadIdx.x==0)
			idxs[i*m+0]=old;
		// reset the running minimum distances; each thread only ever touches the
		// temp entries k = threadIdx.x + t*blockDim.x, here and in the loop below
		for (int j=threadIdx.x;j<n;j+=blockDim.x){
			temp[blockIdx.x*n+j]=1e38f;
		}
		const int cachedFloats=min(BufferSize,n)*3;
		for (int j=threadIdx.x;j<cachedFloats;j+=blockDim.x){
			buf[j]=dataset[i*n*3+j];
		}
		__syncthreads();
		for (int j=1;j<m;j++){
			int besti=0;
			float best=-1;
			float x1=dataset[i*n*3+old*3+0];
			float y1=dataset[i*n*3+old*3+1];
			float z1=dataset[i*n*3+old*3+2];
			for (int k=threadIdx.x;k<n;k+=blockDim.x){
				float td=temp[blockIdx.x*n+k];
				float x2,y2,z2;
				if (k<BufferSize){
					x2=buf[k*3+0];
					y2=buf[k*3+1];
					z2=buf[k*3+2];
				}else{
					x2=dataset[i*n*3+k*3+0];
					y2=dataset[i*n*3+k*3+1];
					z2=dataset[i*n*3+k*3+2];
				}
				// squared distance to the last selected point, folded into the running minimum
				float d=(x2-x1)*(x2-x1)+(y2-y1)*(y2-y1)+(z2-z1)*(z2-z1);
				float d2=min(d,td);
				if (d2!=td)
					temp[blockIdx.x*n+k]=d2;
				if (d2>best){
					best=d2;
					besti=k;
				}
			}
			dists[threadIdx.x]=best;
			dists_i[threadIdx.x]=besti;
			// block-wide max reduction; the winner ends up in slot 0.
			// The strict '<' keeps the lower slot on ties.
			for (int u=0;(1<<u)<blockDim.x;u++){
				__syncthreads();
				if (threadIdx.x<(blockDim.x>>(u+1))){
					int i1=(threadIdx.x*2)<<u;
					int i2=(threadIdx.x*2+1)<<u;
					if (dists[i1]<dists[i2]){
						dists[i1]=dists[i2];
						dists_i[i1]=dists_i[i2];
					}
				}
			}
			__syncthreads();
			old=dists_i[0];
			if (threadIdx.x==0)
				idxs[i*m+j]=old;
			// Every thread must have read dists_i[0] before any thread stores its
			// next candidate into dists/dists_i; without this barrier a fast thread 0
			// can overwrite slot 0 while slower warps have not fetched the winner yet.
			// j and m are block-uniform, so all threads reach this barrier.
			__syncthreads();
		}
	}
}
// Copies selected 3-float points: for every batch i and slot j,
//   out[(i*m+j)*3 + c] = inp[(i*n + idx[i*m+j])*3 + c],  c = 0,1,2
//   b   - number of batches
//   n   - points per batch in inp
//   m   - indices (and output points) per batch
//   idx - point indices into the batch's n points (not range-checked)
// Batches are distributed over blockIdx.x/gridDim.x, slots over blockIdx.y and
// threadIdx.x with stride blockDim.x*gridDim.y.
__global__ void gatherpointKernel(int b,int n,int m,const float * __restrict__ inp,const int * __restrict__ idx,float * __restrict__ out){
	const int firstSlot=blockIdx.y*blockDim.x+threadIdx.x;
	const int slotStride=blockDim.x*gridDim.y;
	for (int batch=blockIdx.x;batch<b;batch+=gridDim.x){
		for (int slot=firstSlot;slot<m;slot+=slotStride){
			const int flat=batch*m+slot;
			const int src=(batch*n+idx[flat])*3;
			const int dst=flat*3;
			// copy the three coordinates of the selected point
			for (int c=0;c<3;c++)
				out[dst+c]=inp[src+c];
		}
	}
}
// Accumulates 3-float values into indexed destinations: for every batch i and slot j,
//   inp_g[(i*n + idx[i*m+j])*3 + c] += out_g[(i*m+j)*3 + c],  c = 0,1,2
//   b     - number of batches
//   n     - destination points per batch in inp_g
//   m     - source points (and indices) per batch
//   idx   - destination point index for every source slot (not range-checked)
//   inp_g - accumulated into with atomicAdd, so several slots may target the same
//           point; it is not cleared here
// Batches are distributed over blockIdx.x/gridDim.x, slots over blockIdx.y and
// threadIdx.x with stride blockDim.x*gridDim.y.
__global__ void scatteraddpointKernel(int b,int n,int m,const float * __restrict__ out_g,const int * __restrict__ idx,float * __restrict__ inp_g){
	const int firstSlot=blockIdx.y*blockDim.x+threadIdx.x;
	const int slotStride=blockDim.x*gridDim.y;
	for (int batch=blockIdx.x;batch<b;batch+=gridDim.x){
		for (int slot=firstSlot;slot<m;slot+=slotStride){
			const int flat=batch*m+slot;
			const int dst=(batch*n+idx[flat])*3;
			const int src=flat*3;
			// add the three components of this slot to its destination point
			for (int c=0;c<3;c++)
				atomicAdd(&inp_g[dst+c],out_g[src+c]);
		}
	}
}
// Launches cumsumKernel with a fixed configuration of 32 blocks x 512 threads on
// the default stream. All arguments are forwarded unchanged; the kernel reads
// inp and writes out as b rows of n floats. The launch is not checked for errors
// and the function does not wait for the kernel to finish.
void cumsumLauncher(int b,int n,const float * inp,float * out){
	const dim3 gridShape(32,1,1);
	const dim3 blockShape(512,1,1);
	cumsumKernel<<<gridShape,blockShape>>>(b,n,inp,out);
}
// Two-step launch on the default stream:
//   1. cumsumKernel writes the row-wise cumulative sums of inp_p (b x n) into temp;
//   2. binarysearchKernel searches those sums for the m values per row in inp_r
//      and writes the resulting positions to out (b x m).
// temp must provide b*n floats of working space.
// Both launches go to the same stream, so the search runs after the cumulative
// sum has completed. Launch errors are not checked.
void probsampleLauncher(int b,int n,int m,const float * inp_p,const float * inp_r,float * temp,int * out){
	const dim3 threads(512,1,1);
	const dim3 scanGrid(32,1,1);
	const dim3 searchGrid(32,8,1);
	cumsumKernel<<<scanGrid,threads>>>(b,n,inp_p,temp);
	binarysearchKernel<<<searchGrid,threads>>>(b,n,m,temp,inp_r,out);
}
// Launches farthestpointsamplingKernel with 32 blocks x 512 threads on the
// default stream; arguments are forwarded unchanged.
// temp must provide 32*n floats of working space: the kernel indexes it per
// block as temp[blockIdx.x*n + k], and 32 blocks are launched here.
// Launch errors are not checked.
void farthestpointsamplingLauncher(int b,int n,int m,const float * inp,float * temp,int * out){
	const dim3 gridShape(32,1,1);
	const dim3 blockShape(512,1,1);
	farthestpointsamplingKernel<<<gridShape,blockShape>>>(b,n,m,inp,temp,out);
}
// Launches gatherpointKernel on a 2 x 8 grid of 512-thread blocks (default
// stream); arguments are forwarded unchanged. Launch errors are not checked.
void gatherpointLauncher(int b,int n,int m,const float * inp,const int * idx,float * out){
	const dim3 gridShape(2,8,1);
	const dim3 blockShape(512,1,1);
	gatherpointKernel<<<gridShape,blockShape>>>(b,n,m,inp,idx,out);
}
// Launches scatteraddpointKernel on a 2 x 8 grid of 512-thread blocks (default
// stream); arguments are forwarded unchanged. The kernel only adds into inp_g,
// so the caller decides what inp_g contains beforehand. Launch errors are not checked.
void scatteraddpointLauncher(int b,int n,int m,const float * out_g,const int * idx,float * inp_g){
	const dim3 gridShape(2,8,1);
	const dim3 blockShape(512,1,1);
	scatteraddpointKernel<<<gridShape,blockShape>>>(b,n,m,out_g,idx,inp_g);
}

